{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch \n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Load the whole CSV as float32: the last column is used as the label, the others as features.\n",
    "data = np.loadtxt(\n",
    "    'diabetes.csv',\n",
    "    delimiter=',',\n",
    "    dtype=np.float32,\n",
    ")\n",
    "# Indexing with the column list [-1] keeps the labels two-dimensional (one column per row).\n",
    "x_data = torch.from_numpy(data[:, :-1])\n",
    "y_data = torch.from_numpy(data[:, [-1]])\n",
    "\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model, self).__init__()\n",
    "        self.liner1 = torch.nn.Linear(8,6)\n",
    "        self.liner2 = torch.nn.Linear(6,4)\n",
    "        self.liner3 = torch.nn.Linear(4,1)\n",
    "        self.sigmoid = torch.nn.Sigmoid()\n",
    "        self.ReLu = torch.nn.ReLU()\n",
    "\n",
    "    def forward(self,x):\n",
    "        x = self.ReLu(self.liner1(x))\n",
    "        x = self.ReLu(self.liner2(x))\n",
    "        x = self.sigmoid(self.liner3(x))\n",
    "        return x\n",
    "\n",
    "# Build the network, the binary cross-entropy loss and the Adam optimizer.\n",
    "model = Model()\n",
    "criterion = torch.nn.BCELoss(reduction=\"mean\")\n",
    "optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)\n",
    "\n",
    "# Full-batch training: each epoch is one forward/backward pass over all of x_data.\n",
    "loss_list = []\n",
    "for epoch in range(10000):\n",
    "    y_hat = model(x_data)\n",
    "    loss = criterion(y_hat, y_data)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    # Clear the gradients so the next backward() does not accumulate onto them.\n",
    "    optimizer.zero_grad()\n",
    "    loss_list.append(loss.item())\n",
    "# One x-axis entry per recorded loss value.\n",
    "epoch_list = list(range(len(loss_list)))\n",
    "\n",
    "# Report the final loss and plot the loss curve.\n",
    "print(loss_list[-1])\n",
    "plt.plot(epoch_list, loss_list)\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('loss')\n",
    "plt.show()\n",
    "\n",
    "# Evaluate on a separate file: the last column is the label, the others are features.\n",
    "test = np.loadtxt('test.csv', delimiter=',', dtype=np.float32)\n",
    "x_test = torch.from_numpy(test[:, :-1])\n",
    "y_test = torch.from_numpy(test[:, [-1]])\n",
    "# Inference only: no_grad() avoids recording an autograd graph for the predictions.\n",
    "with torch.no_grad():\n",
    "    y_pred = model(x_test)\n",
    "# Threshold the model outputs at 0.5 to obtain hard 0/1 labels (float, same shape as y_pred).\n",
    "y_pred_label = (y_pred >= 0.5).float()\n",
    "# Accuracy = number of matching labels / number of test rows.\n",
    "acc = torch.eq(y_pred_label, y_test).sum().item() / y_test.size(0)\n",
    "print(\"test acc:\", acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    " \n",
    " \n",
    "# Load the raw data, then split it into a training set and a test set.\n",
    "raw_data = np.loadtxt('diabetes.csv.gz', delimiter=',', dtype=np.float32)\n",
    "X = raw_data[:, :-1]   # every column except the last: features\n",
    "y = raw_data[:, [-1]]  # last column, kept two-dimensional: labels\n",
    "# Pin random_state so the same 30% of rows is held out on every run.\n",
    "Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, y, test_size=0.3, random_state=42)\n",
    "# The test split is fed to the model as whole tensors, so convert it once here.\n",
    "Xtest = torch.from_numpy(Xtest)\n",
    "Ytest = torch.from_numpy(Ytest)\n",
    " \n",
    "# Wrap the training split in a Dataset so that it can be served in mini-batches\n",
    "# prepare dataset\n",
    " \n",
    "class DiabetesDataset(Dataset):\n",
    "    def __init__(self, data,label):\n",
    " \n",
    "        self.len = data.shape[0] # shape(多少行，多少列)\n",
    "        self.x_data = torch.from_numpy(data)\n",
    "        self.y_data = torch.from_numpy(label)\n",
    " \n",
    "    def __getitem__(self, index):\n",
    "        return self.x_data[index], self.y_data[index]\n",
    " \n",
    "    def __len__(self):\n",
    "        return self.len\n",
    " \n",
    " \n",
    "train_dataset = DiabetesDataset(Xtrain, Ytrain)\n",
    "# Mini-batches of 32, reshuffled every epoch; num_workers=0 loads the data in the main process\n",
    "# (num_workers sets the number of loader subprocesses).\n",
    "train_loader = DataLoader(\n",
    "    dataset=train_dataset,\n",
    "    batch_size=32,\n",
    "    shuffle=True,\n",
    "    num_workers=0,\n",
    ")\n",
    " \n",
    "# design model using class\n",
    " \n",
    " \n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model, self).__init__()\n",
    "        self.linear1 = torch.nn.Linear(8, 6)\n",
    "        self.linear2 = torch.nn.Linear(6, 4)\n",
    "        self.linear3 = torch.nn.Linear(4, 2)\n",
    "        self.linear4 = torch.nn.Linear(2, 1)\n",
    "        self.sigmoid = torch.nn.Sigmoid()\n",
    " \n",
    "    def forward(self, x):\n",
    "        x = self.sigmoid(self.linear1(x))\n",
    "        x = self.sigmoid(self.linear2(x))\n",
    "        x = self.sigmoid(self.linear3(x))\n",
    "        x = self.sigmoid(self.linear4(x))\n",
    "        return x\n",
    " \n",
    " \n",
    "# Fresh network instance for the mini-batch run.\n",
    "model = Model()\n",
    "\n",
    "# Loss: binary cross-entropy averaged over the batch. Optimizer: SGD with learning rate 0.01.\n",
    "criterion = torch.nn.BCELoss(reduction='mean')\n",
    "optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)\n",
    " \n",
    " \n",
    "# training cycle forward, backward, update\n",
    " \n",
    "def train(epoch):\n",
    "    \"\"\"Run one pass over train_loader, updating the model after every mini-batch.\n",
    "\n",
    "    Args:\n",
    "        epoch: Zero-based epoch index; the mean batch loss is printed only when\n",
    "            epoch % 2000 == 1999.\n",
    "    \"\"\"\n",
    "    train_loss = 0.0\n",
    "    num_batches = 0\n",
    "    for inputs, labels in train_loader:\n",
    "        y_pred = model(inputs)\n",
    "        loss = criterion(y_pred, labels)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_loss += loss.item()\n",
    "        # Count batches directly: the last enumerate() index would be one less than the batch count.\n",
    "        num_batches += 1\n",
    "\n",
    "    if epoch % 2000 == 1999:\n",
    "        # Average over the batches actually seen; an empty loader reports nan instead of raising ZeroDivisionError.\n",
    "        mean_loss = train_loss / num_batches if num_batches else float('nan')\n",
    "        print(\"train loss:\", mean_loss, end=',')\n",
    " \n",
    " \n",
    "def test():\n",
    "    \"\"\"Print the accuracy of the current model on the Xtest/Ytest tensors.\"\"\"\n",
    "    # Evaluation only, so gradient tracking is switched off.\n",
    "    with torch.no_grad():\n",
    "        probs = model(Xtest)\n",
    "        # Outputs at or above 0.5 count as class 1, everything else as class 0.\n",
    "        predicted = (probs >= 0.5).float()\n",
    "        correct = (predicted == Ytest).sum().item()\n",
    "        acc = correct / Ytest.size(0)\n",
    "        print(\"test acc:\", acc)\n",
    " \n",
    "if __name__ == '__main__':\n",
    "    # train() prints its loss on the same epochs, so each report line reads: train loss, then test acc.\n",
    "    for epoch in range(50000):\n",
    "        train(epoch)\n",
    "        # Evaluate after every 2000th epoch.\n",
    "        if (epoch + 1) % 2000 == 0:\n",
    "            test()\n",
    " "
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "44143c5ac5f3ceb8e37c69c3af73325ae55d21292b2c7b54871fd886482dde4c"
  },
  "kernelspec": {
   "display_name": "Python 3.6.13 ('pytorch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
